"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script gathers the clustering results for the APS dataset that are then
used to draw Figs. 8, 9 and S5.

It saves the results as `multi_complenet_partitions_all_comps.pickle`,
                        `multi_complenet_partitions_all_comps_active_nodes_sorted.pickle`
                        and `aps_flow_stability_clustering.json`.

"""
import sys
import os
# Make the directory one level above this script importable, so that the
# project modules imported below (TemporalNetwork, TemporalStability) resolve.
PACKAGE_PARENT = '..'
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(
        os.path.join(os.getcwd(), os.path.expanduser(__file__))
    )
)
sys.path.append(
    os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))
)



import numpy as np
import pandas as pd
from TemporalNetwork import ContTempNetwork
from TemporalStability import Partition


import matplotlib.pyplot as plt


# Use the project's matplotlib style sheet for any figure drawn from here on.
plt.style.use('alex_paper')

# Inputs (saved network + clustering results) live under datadir;
# figures go to figdir, which is created if missing.
datadir = '../paper_data/aps/'
country = 'complenet'
clusterdir_fwd = os.path.join(datadir, 'clustersI_complenet_forward')
figdir = '../figures/aps/complenet'
os.makedirs(figdir, exist_ok=True)

# Common prefix of the cluster result file names, and path of the saved network.
file_prefix, net_file = 'aps_monthly', os.path.join(datadir, 'aps_monthly_lcc_net')

# Intentional stop. The rest of the file is split into `#%%` cells, presumably
# meant to be executed one at a time in an IDE rather than as a whole script.
# NOTE(review): confirm this guard is still wanted.
raise Exception

#%% collect the parameter grid from the cluster result file names

# Result files are named
#   clusters_<file_prefix>_tau_w<tau_w:.3e>_PTforw_<start:06d>_to_<stop:06d>.hkl
# (the same pattern is used to read them back below). After splitting the
# stem on '_', start/stop are the 3rd-last and last tokens, and tau_w is the
# token that starts with 'w'.
files = os.listdir(clusterdir_fwd)

interval_starts = set()
interval_ends = set()
tau_ws = set()
for file in files:
    if not file.startswith('clusters'):
        continue
    # split the stem once and reuse the tokens for every parameter
    extracts = os.path.splitext(file)[0].split('_')
    interval_starts.add(int(extracts[-3]))
    interval_ends.add(int(extracts[-1]))
    tau_ws.update(float(extract[1:]) for extract in extracts
                  if extract.startswith('w'))
                
                
            
#%% network file

# Load only the attributes of the saved temporal network used in this script.
net = ContTempNetwork.load(
    net_file,
    attributes_list=[
        'node_to_label_dict',
        'events_table',
        'times',
        'time_grid',
        'num_nodes',
        'time_slices_bounds',
        'time_slices_bounds_datetimes',
    ],
)

# Bounds of the time slices; consecutive pairs (start, end) delimit one slice.
time_slices = net.time_slices_bounds

# Inverse of net.node_to_label_dict: label -> node id.
label_to_node_dict = dict(
    (label, node) for node, label in net.node_to_label_dict.items()
)

#%% size of the network per year

# active_nodes_per_year[i]: nodes appearing as source or target of at least
# one event that overlaps slice i, i.e. starting before the slice end and
# ending at or after the slice start. The slices are presumably yearly, given
# the variable names and the 10-slice decades below; TODO confirm.
active_nodes_per_year = dict()

for i, (start, end) in enumerate(zip(time_slices[:-1], time_slices[1:])):
    print(i, start, end)

    # compute the overlap mask once and use it for both endpoints
    overlapping = net.events_table.loc[
        np.logical_and(net.events_table.starting_times < end,
                       net.events_table.ending_times >= start)]

    active_nodes_per_year[i] = (set(overlapping.source_nodes.unique())
                                | set(overlapping.target_nodes.unique()))

# active_nodes_per_year_cummul[i]: union of the active nodes of slices 0..i
active_nodes_per_year_cummul = dict()
nodes = set()
for i, node_set in active_nodes_per_year.items():
    nodes.update(node_set)
    active_nodes_per_year_cummul[i] = nodes.copy()

# active_nodes_per_decade[label]: union of the active nodes over the slice
# indices of each decade label.
decades = {n: r for n, r in zip(['70s', '80s', '90s', '00s'],
                                [range(78, 88), range(88, 98),
                                 range(98, 108), range(108, 118)])}
active_nodes_per_decade = {
    dec: set().union(*(active_nodes_per_year[i] for i in r))
    for dec, r in decades.items()
}
#%% get multi res

# multi_res[tau_w] = {'partitions_forw': [...], 'nvarinfs_forw': [...]},
# with one entry per clustering interval, in the order of interval_pairs.
multi_res = {}

taus_to_plot = sorted(tau_ws)

# (start, stop) slice indices of each clustering run, highest indices first.
interval_pairs = [[118, 108], [108, 98], [98, 88], [88, 78]]

for tau_w in taus_to_plot:
    print(tau_w)

    # here 'forw' actually corresponds to 'backward' in the paper, i.e. in the
    # direction of the time-reversed diffusion process
    multi_res[tau_w] = {'partitions_forw': [], 'nvarinfs_forw': []}

    for int_start, int_stop in interval_pairs:
        cluster_file = os.path.join(
            clusterdir_fwd,
            'clusters_' + file_prefix + '_tau_w'
            + '{0:.3e}_PTforw_{1:06d}_to_{2:06d}.hkl'.format(tau_w,
                                                             int_start,
                                                             int_stop))
        clustres = pd.read_pickle(cluster_file)

        clist_forw = clustres['best_cluster_sym']

        # partitions with only active nodes, largest clusters first
        best_part = Partition(sum(len(c) for c in clist_forw),
                              sorted(clist_forw, key=len, reverse=True))

        multi_res[tau_w]['partitions_forw'].append(best_part)
        multi_res[tau_w]['nvarinfs_forw'].append(clustres['nvarinf_sym'])

#%%

pd.to_pickle(multi_res, os.path.join(datadir, 'multi_complenet_partitions_all_comps.pickle'))

#%% intersect with active nodes

# Decade label of each stored partition, matching the order of interval_pairs
# (slice ranges 118->108, 108->98, 98->88, 88->78).
decade_order = ['00s', '90s', '80s', '70s']

new_multi_res = {}
for tau_w in tau_ws:
    new_multi_res[tau_w] = {}
    # multi_res[tau_w] holds both 'partitions_<dir>' and 'nvarinfs_<dir>'
    # entries. Only the partition lists are intersected, so extract the <dir>
    # suffix from the 'partitions_' keys. Prefixing the raw keys again would
    # look up 'partitions_partitions_forw' and raise KeyError.
    directions = [key[len('partitions_'):] for key in multi_res[tau_w]
                  if key.startswith('partitions_')]
    for dire in directions:
        new_partitions = []
        for part, dec in zip(multi_res[tau_w]['partitions_' + dire], decade_order):
            # intersection and sort with largest clusters first
            new_c_list = sorted((cl & active_nodes_per_decade[dec]
                                 for cl in part.cluster_list),
                                key=len, reverse=True)
            new_partitions.append(Partition(sum(len(c) for c in new_c_list),
                                            new_c_list))
        new_multi_res[tau_w]['partitions_' + dire] = new_partitions

#%%

pd.to_pickle(new_multi_res, os.path.join(datadir, 'multi_complenet_partitions_all_comps_active_nodes_sorted.pickle'))

#%% save results to share
tau_w = 3650.0

# For each decade partition, replace internal node ids with their original
# labels via net.node_to_label_dict ('forw' here is the paper's backward direction).
partitions_back_per_decade = [
    [[net.node_to_label_dict[n] for n in clust] for clust in p.cluster_list]
    for p in new_multi_res[tau_w]['partitions_forw']
]

# Single shareable record. Key names are part of the output format and are
# kept as-is; the NVI values come from the 'forw' (paper: backward) runs.
res_share = [
    {
        'tau_w': tau_w,
        'partitions_back_per_decade': partitions_back_per_decade,
        'NVI_forward_per_decade': multi_res[tau_w]['nvarinfs_forw'],
        'decades': ['2000 to 2010', '1990 to 2000',
                    '1980 to 1990', '1970 to 1980'],
        'info': 'the node IDs correspond to the IDs of the name disambiguation from  R. Sinatra et al. Science 354, aaf5239–aaf5239 (2016). The list is given in the SM of this article.',
    },
]

import json

class NpEncoder(json.JSONEncoder):
    """JSON encoder that also handles numpy scalars/arrays and Python sets.

    numpy integers become ``int``, numpy floats become ``float``, arrays become
    nested lists and sets become lists. Any other unsupported object is
    delegated to the base encoder, which raises ``TypeError``.
    """

    def default(self, obj):
        # checked in order; the first matching type wins
        conversions = (
            (np.integer, int),
            (np.floating, float),
            (np.ndarray, lambda arr: arr.tolist()),
            (set, list),
        )
        for kind, convert in conversions:
            if isinstance(obj, kind):
                return convert(obj)
        return super().default(obj)


# Write the shareable results as JSON. The output directory is created first
# if it does not exist yet; otherwise open() would raise FileNotFoundError.
os.makedirs('../data', exist_ok=True)
with open('../data/aps_flow_stability_clustering.json', 'w', encoding='utf-8') as fopen:
    json.dump(res_share, fopen, cls=NpEncoder)

